{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use a GNN to Predict an ADMET Property\n",
    "\n",
    "A demonstration of submitting to the TDC ADMET Caco2_Wang benchmark."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: Load the benchmark dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found local copy...\n"
     ]
    }
   ],
   "source": [
    "from tdc import BenchmarkGroup\n",
    "group = BenchmarkGroup(name = 'ADMET_Group', path = 'data/')\n",
    "benchmark = group.get('Caco2_Wang')\n",
    "\n",
    "train_val, test = benchmark['train_val'], benchmark['test']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2: Train your model over five runs\n",
    "\n",
    "We use [DeepPurpose](https://github.com/kexinhuang12345/DeepPurpose), a scikit-learn-style deep learning library for drug discovery, as an example. The model is trained five times, each with a different seed for the train/validation split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1355.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 634\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 91\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "predicting...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1300.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 635\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 90\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "predicting...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1245.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 634\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 91\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "predicting...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1257.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 634\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 91\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "predicting...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "generating training, validation splits...\n",
      "100%|██████████| 728/728 [00:00<00:00, 1274.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Drug Property Prediction Mode...\n",
      "in total: 637 drugs\n",
      "encoding drug...\n",
      "unique drugs: 635\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 91 drugs\n",
      "encoding drug...\n",
      "unique drugs: 90\n",
      "do not do train/test split on the data for already splitted data\n",
      "Drug Property Prediction Mode...\n",
      "in total: 182 drugs\n",
      "encoding drug...\n",
      "unique drugs: 181\n",
      "do not do train/test split on the data for already splitted data\n",
      "predicting...\n"
     ]
    }
   ],
   "source": [
    "from DeepPurpose import CompoundPred as models\n",
    "from DeepPurpose.utils import data_process, generate_config\n",
    "\n",
    "drug_encoding = 'MPNN'\n",
    "prediction_runs = []\n",
    "\n",
    "for seed in [1, 2, 3, 4, 5]:\n",
    "    ### Generate Different Train, Valid Splits Given Seed ###\n",
    "    train, valid = group.get_train_valid_split(benchmark = benchmark['name'], split_type = 'default', seed = seed)\n",
    "    \n",
    "    ### Train the Model on Train, Valid Set ###\n",
    "    train = data_process(X_drug = train.Drug.values, y = train.Y.values, drug_encoding = drug_encoding, split_method='no_split')\n",
    "    val = data_process(X_drug = valid.Drug.values, y = valid.Y.values, drug_encoding = drug_encoding, split_method='no_split')\n",
    "    test = data_process(X_drug = benchmark['test'].Drug.values, y = benchmark['test'].Y.values, drug_encoding = drug_encoding, split_method='no_split')\n",
    "\n",
    "    config = generate_config(drug_encoding = drug_encoding, train_epoch = 10, LR = 0.001, batch_size = 128)\n",
    "    model = models.model_initialize(**config)\n",
    "    model.train(train, val, test, verbose = False)\n",
    "    \n",
    "    ### Generate Predictions on the Test Set ###\n",
    "    y_pred = model.predict(test)\n",
    "    prediction_runs.append({benchmark['name']: y_pred})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: Evaluate the test set predictions with the pre-specified TDC evaluator\n",
    "\n",
    "`evaluate_many` returns the mean and standard deviation of the benchmark metric across the five runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'caco2_wang': [0.64, 0.028]}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "group.evaluate_many(prediction_runs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4: Copy the results above and submit them through [this form](https://forms.gle/HYupGaV7WDuutbr9A).\n",
    "\n",
    "That's it! Your results will appear on the [leaderboard website](https://tdcommons.ai/benchmark/admet_group/01caco2/) soon."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:DeepPurpose]",
   "language": "python",
   "name": "conda-env-DeepPurpose-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
